{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GrbQw8Glkddw"
      },
      "source": [
        "Copyright 2021 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "imz8IabSlEVM"
      },
      "outputs": [],
      "source": [
        "# Fetch the google-research repo and put it on sys.path so that the\n",
        "# `covid_epidemiology` package imported by later cells can be found.\n",
        "# Both steps are guarded so re-running this cell on a live kernel neither\n",
        "# attempts a second clone nor appends a duplicate sys.path entry.\n",
        "import os\n",
        "import sys\n",
        "\n",
        "REPO_DIR = './google-research'\n",
        "\n",
        "if not os.path.isdir(REPO_DIR):\n",
        "  # Shallow clone: only the current tree is needed, not the full history.\n",
        "  !git clone --depth 1 https://github.com/google-research/google-research.git {REPO_DIR}\n",
        "\n",
        "if REPO_DIR not in sys.path:\n",
        "  sys.path.append(REPO_DIR)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v-Ho7qCCW-E2"
      },
      "outputs": [],
      "source": [
        "# Verify credentials for BQ.\n",
        "\n",
        "from colabtools import auth\n",
        "from colabtools import bigquery\n",
        "\n",
        "from colabtools import adhoc_import\n",
        "from covid_epidemiology import colab_utils\n",
        "\n",
        "# Authenticate the current user with the BigQuery OAuth scopes, then build\n",
        "# the `client` that later cells use for every BigQuery call.\n",
        "creds = auth.get_user_oauth2_credentials(bigquery.SCOPES)\n",
        "client = bigquery.Client(\n",
        "    credentials=creds,\n",
        "    project=colab_utils.constants.PROJECT_ID_MODEL_TRAINING,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S2NTjaZbKRNY"
      },
      "source": [
        "## Utils"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ESTgFZkyiBkx"
      },
      "source": [
        "### Prospective utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YocE3m3LiJAu"
      },
      "outputs": [],
      "source": [
        "import datetime\n",
        "\n",
        "# Parse both bounds of the prospective window (stored as '%Y-%m-%d' strings\n",
        "# in colab_utils.constants) into datetime.date objects.\n",
        "prospective_start, prospective_end = (\n",
        "    datetime.datetime.strptime(raw_date, '%Y-%m-%d').date()\n",
        "    for raw_date in (\n",
        "        colab_utils.constants.PROSPECTIVE_START_DATE,\n",
        "        colab_utils.constants.PROSPECTIVE_END_DATE,\n",
        "    )\n",
        ")\n",
        "\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import time\n",
        "def gather_prospective_data(\n",
        "    forecast_pd, pred_feature_key, gt_feature_keys, loc_key, locale,\n",
        "    use_latest_gt, debug=False):\n",
        "  \"\"\"Pairs prospective forecasts with ground truth (GT) per date/location.\n",
        "\n",
        "  Reads GT through the notebook-level BigQuery `client`.\n",
        "\n",
        "  Args:\n",
        "    forecast_pd: Forecast dataframe with `forecast_date`, `prediction_date`,\n",
        "      `loc_key` and `pred_feature_key` columns (and `prefecture_name` when\n",
        "      locale is 'japan'). The caller's frame is left untouched.\n",
        "    pred_feature_key: Column holding the predictions; renamed to\n",
        "      `predictions` before grouping.\n",
        "    gt_feature_keys: List of GT feature keys to fetch.\n",
        "    loc_key: Name of the location column in `forecast_pd`.\n",
        "    locale: Either 'japan' or 'state'.\n",
        "    use_latest_gt: Forwarded to `colab_utils.get_gt_version_names`.\n",
        "    debug: If True, prints the forecast-date to GT-version mapping; also\n",
        "      forwarded to `colab_utils.gather_data_from_prospective_row`.\n",
        "\n",
        "  Returns:\n",
        "    Dataframe with one row per (forecast_date, location), where\n",
        "    `forecast_date` is a string and the location column is named\n",
        "    `location_name`. The remaining column holds the result of\n",
        "    `colab_utils.gather_data_from_prospective_row` for that group (per the\n",
        "    original docstring, a dict of {dates, preds, gts}).\n",
        "\n",
        "  Raises:\n",
        "    AssertionError: If one of the input or GT-version sanity checks fails.\n",
        "    ValueError: If no forecast row has prediction_date after forecast_date.\n",
        "  \"\"\"\n",
        "  assert isinstance(gt_feature_keys, list)\n",
        "  assert locale in ('japan', 'state')\n",
        "  if locale == 'japan':\n",
        "    # Sanity check that we're using ex \"Hyogo\" and not \"Hyōgo\".\n",
        "    assert set(colab_utils.kaz_locations_to_open_covid_map().keys()).issubset(\n",
        "        set(forecast_pd.prefecture_name.unique()))\n",
        "\n",
        "  # Only use the relevant subset of forecasts. The explicit copy gives this\n",
        "  # function its own frame, so adding the `gt_version` column below cannot\n",
        "  # trigger pandas' SettingWithCopyWarning.\n",
        "  forecast_pd = forecast_pd[forecast_pd.prediction_date \u003e forecast_pd.forecast_date]\n",
        "  forecast_pd = forecast_pd[\n",
        "      [\"forecast_date\", \"prediction_date\", loc_key, pred_feature_key]].copy()\n",
        "  if forecast_pd.empty:\n",
        "    # Fail early with a clear message instead of an opaque min()/max() error.\n",
        "    raise ValueError(\n",
        "        'No forecast rows with prediction_date after forecast_date.')\n",
        "  assert forecast_pd[pred_feature_key].notna().all()\n",
        "\n",
        "  # Add GT version to forecast_pd.\n",
        "  min_version = {\n",
        "      'japan': '2020-10-07 20:01:32 UTC',\n",
        "      'state': None,\n",
        "  }[locale]\n",
        "  forecast_dates = list(forecast_pd.forecast_date.unique())\n",
        "  gt_versions = colab_utils.get_gt_version_names(\n",
        "      forecast_dates, locale, min_version, use_latest_gt=use_latest_gt,\n",
        "      client=client)\n",
        "  forecast_to_version_map = dict(zip(forecast_dates, gt_versions))\n",
        "  forecast_pd['gt_version'] = forecast_pd.forecast_date.replace(\n",
        "      forecast_to_version_map)\n",
        "  print('Got all GT versions.')\n",
        "  if debug:\n",
        "    for k, v in forecast_to_version_map.items():\n",
        "      print(f'{k}: {v}')\n",
        "\n",
        "  # Get ground truth forecasts.\n",
        "  gt_versions_to_use = forecast_pd.gt_version.unique().tolist()\n",
        "  if locale == 'japan' and not use_latest_gt:\n",
        "    # Add an extra buffer at the end so automatic increment can work.\n",
        "    gt_versions_to_use += colab_utils.get_gt_version_names(\n",
        "      [max(forecast_dates) + datetime.timedelta(days=i) for i in range(1, 10)],\n",
        "      locale,\n",
        "      min_version,\n",
        "      use_latest_gt=use_latest_gt,\n",
        "      client=client)\n",
        "  gt_df = colab_utils.get_all_gt(\n",
        "      # Get the GT for \"day 0\" in case we want to compute incident cases/deaths,\n",
        "      # which is always computed as a delta from GT at day 0. So we get 2 days\n",
        "      # before the first prediction date, instead of 1 day.\n",
        "      min(forecast_pd.prediction_date) - datetime.timedelta(days=2),\n",
        "      # NOTE(review): the original comment promised a date buffer at the end\n",
        "      # for auto-increment, but no buffer is added here -- confirm whether one\n",
        "      # is needed.\n",
        "      max(forecast_pd.prediction_date),\n",
        "      locale=locale,\n",
        "      bq_client=client,\n",
        "      version=gt_versions_to_use,\n",
        "      feature_keys=gt_feature_keys)\n",
        "  print('Finished reading GT.')\n",
        "\n",
        "  # Check that versions are subset of available versions.\n",
        "  available_versions = gt_df.version.dt.strftime(\n",
        "      '%Y-%m-%d %H:%M:%S+00:00')\n",
        "  assert forecast_pd.gt_version.isin(available_versions).all()\n",
        "\n",
        "  # Get predictions and GT in one row.\n",
        "  start_time = time.time()\n",
        "  forecast_pd = forecast_pd.rename(columns={pred_feature_key: 'predictions'})\n",
        "  gathered_data_df = forecast_pd.groupby(\n",
        "      ['forecast_date', loc_key]).apply(\n",
        "          colab_utils.gather_data_from_prospective_row,\n",
        "          gt_df=gt_df,\n",
        "          locale=locale,\n",
        "          available_versions=np.unique(available_versions.values),\n",
        "          debug=debug)\n",
        "  gathered_data_df = gathered_data_df.to_frame().reset_index().convert_dtypes()\n",
        "  gathered_data_df['forecast_date'] = gathered_data_df.forecast_date.astype(str)\n",
        "  gathered_data_df = gathered_data_df.rename(columns={loc_key: 'location_name'})\n",
        "  end_time = time.time()\n",
        "  total_time_min = (end_time - start_time) / 60.0\n",
        "  print(f'Total time: {total_time_min:.2f} min')\n",
        "  print(f'Time per date: {total_time_min / len(gathered_data_df):.2f} min / item')\n",
        "\n",
        "  return gathered_data_df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kHguaO4-n_6d"
      },
      "source": [
        "### Plotting and filtering utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XqmSjAMposeJ"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "import matplotlib\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "def plot_events(events, d_min, max_y, line_top=0.5):\n",
        "  \"\"\"Draws a red vertical line per event and labels the recent ones.\n",
        "\n",
        "  Args:\n",
        "    events: Iterable of (date, name, ybuff, xbuff) tuples. `xbuff` is a\n",
        "      number of days and `ybuff` is scaled by `max_y`; together they offset\n",
        "      the label text from the annotated point.\n",
        "    d_min: Only events dated strictly after `d_min` are annotated; the\n",
        "      vertical line is drawn for every event.\n",
        "    max_y: Scale applied to `line_top` and `ybuff` to get y positions.\n",
        "    line_top: Passed to `plt.axvline` as `ymax` (upper end of the line).\n",
        "  \"\"\"\n",
        "  for event_date, label, y_offset, x_offset_days in events:\n",
        "    plt.axvline(event_date, ymin=0.0, ymax=line_top, color='r')\n",
        "    if not event_date \u003e d_min:\n",
        "      continue\n",
        "    arrow_tip = (event_date, line_top * max_y)\n",
        "    label_position = (\n",
        "        event_date + datetime.timedelta(days=x_offset_days),\n",
        "        (line_top + y_offset) * max_y)\n",
        "    plt.annotate(\n",
        "        label,\n",
        "        xy=arrow_tip,\n",
        "        fontsize=16,\n",
        "        xytext=label_position,\n",
        "        arrowprops={'arrowstyle': '-\u003e'})\n",
        "\n",
        "def calculate_mape(base_pd, calculate_mape_apply_fn_args):\n",
        "  \"\"\"Applies `colab_utils.calculate_mape_apply_fn` per forecast date.\n",
        "\n",
        "  Args:\n",
        "    base_pd: Dataframe with a `forecast_date` column of '%Y-%m-%d' strings.\n",
        "    calculate_mape_apply_fn_args: Dict of keyword arguments forwarded to\n",
        "      `colab_utils.calculate_mape_apply_fn`.\n",
        "\n",
        "  Returns:\n",
        "    Tuple (xs, ys): the forecast dates as `datetime.date` objects and the\n",
        "    matching per-date values, with NaN entries dropped.\n",
        "  \"\"\"\n",
        "  per_date = base_pd.groupby('forecast_date').apply(\n",
        "      colab_utils.calculate_mape_apply_fn, **calculate_mape_apply_fn_args)\n",
        "  per_date = per_date.dropna()\n",
        "  forecast_dates = []\n",
        "  for date_str in per_date.index:\n",
        "    forecast_dates.append(\n",
        "        datetime.datetime.strptime(date_str, '%Y-%m-%d').date())\n",
        "  return forecast_dates, per_date.values\n",
        "\n",
        "def write_to_cns(xs, ys, name):\n",
        "  \"\"\"Forwards a (forecast_date, pred) series to `colab_utils.write_csv_to_cns`.\n",
        "\n",
        "  Args:\n",
        "    xs: Values stored under the 'forecast_date' key.\n",
        "    ys: Values stored under the 'pred' key.\n",
        "    name: Passed through as `graph_name`.\n",
        "  \"\"\"\n",
        "  columns = {'forecast_date': xs, 'pred': ys}\n",
        "  colab_utils.write_csv_to_cns(graph_name=name, data_dict=columns)\n",
        "  \n",
        "def print_mape_stats(avg_type, locale, prosp_confirmed_ys, prosp_deaths_ys):\n",
        "  \"\"\"Prints the mean and 95% confidence interval of prospective MAPE values.\n",
        "\n",
        "  Args:\n",
        "    avg_type: Label of the averaging scheme; prefix of each printed line.\n",
        "    locale: Either 'Japan' or 'US State'.\n",
        "    prosp_confirmed_ys: Values for confirmed cases, or None to skip them.\n",
        "    prosp_deaths_ys: Values for deaths, or None to skip them.\n",
        "\n",
        "  Raises:\n",
        "    AssertionError: If `locale` is not recognized, or if the mean returned by\n",
        "      `colab_utils.mean_confidence_interval` disagrees with `np.nanmean`.\n",
        "  \"\"\"\n",
        "  assert locale in ('Japan', 'US State'), f'Unexpected locale: {locale!r}'\n",
        "\n",
        "  def _print(dat, window, metric, ignore_nan):\n",
        "    if dat is None:\n",
        "      return\n",
        "    mean, lower, upper = colab_utils.mean_confidence_interval(\n",
        "        dat, confidence=0.95, ignore_nan=ignore_nan)\n",
        "    # Compare with a tolerance: the two means are computed independently, so\n",
        "    # exact float equality can fail spuriously. A NaN on either side still\n",
        "    # fails the check (equal_nan defaults to False), as it did with `==`.\n",
        "    expected_mean = np.nanmean(dat)\n",
        "    assert np.isclose(mean, expected_mean), (mean, expected_mean)\n",
        "    print(f'{avg_type}_{locale}_{window}_{metric}_MAPE: '\n",
        "          f'{mean:.2f} [{lower:.2f}, {upper:.2f}]')\n",
        "\n",
        "  _print(prosp_confirmed_ys, 'prospective', 'confirmed', ignore_nan=False)\n",
        "  _print(prosp_deaths_ys, 'prospective', 'deaths', ignore_nan=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9vzRAZivvyXc"
      },
      "source": [
        "## Plots"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y9Jlk1jMas1C"
      },
      "source": [
        "### Japan"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TLmmDHPkDpOO"
      },
      "source": [
        "#### Load prospective data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TqFFaLv5wOMF"
      },
      "outputs": [],
      "source": [
        "# Load prospective Japan data.\n",
        "# Only load forecasts whose forecast_date falls within the prospective\n",
        "# period.\n",
        "bq_table = 'bigquery-public-data.covid19_public_forecasts.japan_prefecture_28d_historical'\n",
        "# Adjacent f-strings are concatenated verbatim, so every fragment carries its\n",
        "# own trailing space to keep the SQL tokens separated.\n",
        "q = (f\"select * from `{bq_table}` \"\n",
        "     f\"where forecast_date \u003e= '{prospective_start.isoformat()}' \"\n",
        "     f\"and forecast_date \u003c= '{prospective_end.isoformat()}'\")\n",
        "jp_forecast_pd = client.query(q).to_dataframe()\n",
        "\n",
        "# Pair the forecasts with ground truth, separately for cases and deaths.\n",
        "jp_prosp_confirmed_df = gather_prospective_data(\n",
        "    jp_forecast_pd,\n",
        "    pred_feature_key='cumulative_confirmed',\n",
        "    gt_feature_keys=['kaz_confirmed_cases', 'open_gt_jp_confirmed_cases'],\n",
        "    loc_key='prefecture_name',\n",
        "    locale='japan',\n",
        "    use_latest_gt=False)\n",
        "jp_prosp_deaths_df = gather_prospective_data(\n",
        "    jp_forecast_pd,\n",
        "    pred_feature_key='cumulative_deaths',\n",
        "    gt_feature_keys=['kaz_deaths', 'open_gt_jp_deaths'],\n",
        "    loc_key='prefecture_name',\n",
        "    locale='japan',\n",
        "    use_latest_gt=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Oehw--RqJQ7l"
      },
      "source": [
        "#### Plot micro average."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_9dQN0v7JTH3"
      },
      "outputs": [],
      "source": [
        "# Micro average.\n",
        "calculate_mape_apply_fn_args = dict(\n",
        "    average_type='micro',\n",
        "    expected_num_locations=47,\n",
        "    min_count=None,\n",
        "    min_mae=None,\n",
        "    value_type='4week',\n",
        ")\n",
        "\n",
        "# Prospective series for confirmed cases and deaths.\n",
        "jp_prosp_confirmed_xs, jp_prosp_confirmed_ys = calculate_mape(\n",
        "    jp_prosp_confirmed_df, calculate_mape_apply_fn_args)\n",
        "jp_prosp_deaths_xs, jp_prosp_deaths_ys = calculate_mape(\n",
        "    jp_prosp_deaths_df, calculate_mape_apply_fn_args)\n",
        "\n",
        "# Plot both prospective series on one axes.\n",
        "fig, ax = plt.subplots(figsize=(16, 6))\n",
        "ax.plot(jp_prosp_confirmed_xs, jp_prosp_confirmed_ys, color='b', marker='o',\n",
        "        label='confirmed, prospective')\n",
        "ax.plot(jp_prosp_deaths_xs, jp_prosp_deaths_ys, color='g', marker='o',\n",
        "        label='deaths, prospective')\n",
        "\n",
        "# Formatting.\n",
        "ax.legend()\n",
        "ax.grid()\n",
        "ax.set_xlabel('train date')\n",
        "ax.set_ylabel('MAPE')\n",
        "ax.set_title('MAPE vs training date, micro average')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CmSz_ARsP_Bk"
      },
      "outputs": [],
      "source": [
        "# Mean and 95% confidence interval of the micro-average series above.\n",
        "print_mape_stats(\n",
        "    locale='Japan',\n",
        "    avg_type='Micro average',\n",
        "    prosp_confirmed_ys=jp_prosp_confirmed_ys,\n",
        "    prosp_deaths_ys=jp_prosp_deaths_ys,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WbWRb5awDuUH"
      },
      "source": [
        "#### Plot macro average."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K8F7a2irtRo1"
      },
      "outputs": [],
      "source": [
        "# Macro average.\n",
        "calculate_mape_apply_fn_args = dict(\n",
        "    average_type='macro',\n",
        "    expected_num_locations=47,\n",
        "    min_count=None,\n",
        "    min_mae=None,\n",
        "    value_type='4week',\n",
        ")\n",
        "\n",
        "# Prospective series for confirmed cases and deaths.\n",
        "jp_prosp_confirmed_xs, jp_prosp_confirmed_ys = calculate_mape(\n",
        "    jp_prosp_confirmed_df, calculate_mape_apply_fn_args)\n",
        "jp_prosp_deaths_xs, jp_prosp_deaths_ys = calculate_mape(\n",
        "    jp_prosp_deaths_df, calculate_mape_apply_fn_args)\n",
        "\n",
        "# Plot both prospective series on one axes.\n",
        "fig, ax = plt.subplots(figsize=(16, 6))\n",
        "ax.plot(jp_prosp_confirmed_xs, jp_prosp_confirmed_ys, color='b', marker='o',\n",
        "        label='confirmed, prospective')\n",
        "ax.plot(jp_prosp_deaths_xs, jp_prosp_deaths_ys, color='g', marker='o',\n",
        "        label='deaths, prospective')\n",
        "\n",
        "# Formatting.\n",
        "ax.legend()\n",
        "ax.grid()\n",
        "ax.set_xlabel('train date')\n",
        "ax.set_ylabel('MAPE')\n",
        "ax.set_title('MAPE vs training date, macro average')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b82FQUXZg802"
      },
      "outputs": [],
      "source": [
        "# Mean and 95% confidence interval of the macro-average series above.\n",
        "print_mape_stats(\n",
        "    locale='Japan',\n",
        "    avg_type='Macro average',\n",
        "    prosp_confirmed_ys=jp_prosp_confirmed_ys,\n",
        "    prosp_deaths_ys=jp_prosp_deaths_ys,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z_Fy98pyDxa3"
      },
      "source": [
        "#### Plot macro average, with min count."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o7_EfGc5_WW7"
      },
      "outputs": [],
      "source": [
        "# Macro average, with min count.\n",
        "\n",
        "calculate_mape_apply_fn_args = dict(\n",
        "    average_type='macro',\n",
        "    expected_num_locations=47,\n",
        "    min_mae=None,\n",
        "    value_type='4week',\n",
        ")\n",
        "\n",
        "# Prospective series. `min_count` is set per metric right before each call:\n",
        "# 1000 for confirmed cases, 10 for deaths.\n",
        "calculate_mape_apply_fn_args['min_count'] = 1000\n",
        "jp_prosp_confirmed_xs, jp_prosp_confirmed_ys = calculate_mape(\n",
        "    jp_prosp_confirmed_df, calculate_mape_apply_fn_args)\n",
        "calculate_mape_apply_fn_args['min_count'] = 10\n",
        "jp_prosp_deaths_xs, jp_prosp_deaths_ys = calculate_mape(\n",
        "    jp_prosp_deaths_df, calculate_mape_apply_fn_args)\n",
        "\n",
        "# Plot both prospective series on one axes.\n",
        "fig, ax = plt.subplots(figsize=(16, 6))\n",
        "ax.plot(jp_prosp_confirmed_xs, jp_prosp_confirmed_ys, color='b', marker='o',\n",
        "        label='confirmed, prospective')\n",
        "ax.plot(jp_prosp_deaths_xs, jp_prosp_deaths_ys, color='g', marker='o',\n",
        "        label='deaths, prospective')\n",
        "\n",
        "# Formatting.\n",
        "ax.grid()\n",
        "ax.legend()\n",
        "ax.set_xlabel('train date')\n",
        "ax.set_ylabel('MAPE')\n",
        "ax.set_title(\n",
        "    'MAPE vs training date, macro average (\u003e=1K cases or 10 deaths)')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1pgHlje6hEk4"
      },
      "outputs": [],
      "source": [
        "# Mean and 95% confidence interval of the min-count series above.\n",
        "print_mape_stats(\n",
        "    locale='Japan',\n",
        "    avg_type='Macro average, with min count',\n",
        "    prosp_confirmed_ys=jp_prosp_confirmed_ys,\n",
        "    prosp_deaths_ys=jp_prosp_deaths_ys,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AFtqhuLEDzck"
      },
      "source": [
        "#### Plot macro average, with min mae."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f7nEzeqhAIlf"
      },
      "outputs": [],
      "source": [
        "# Macro average, with min mae.\n",
        "\n",
        "calculate_mape_apply_fn_args = dict(\n",
        "    average_type='macro',\n",
        "    expected_num_locations=47,\n",
        "    min_mae=10.0,\n",
        "    min_count=None,\n",
        "    value_type='daily',\n",
        ")\n",
        "\n",
        "# Prospective series for confirmed cases and deaths.\n",
        "jp_prosp_confirmed_xs, jp_prosp_confirmed_ys = calculate_mape(\n",
        "    jp_prosp_confirmed_df, calculate_mape_apply_fn_args)\n",
        "jp_prosp_deaths_xs, jp_prosp_deaths_ys = calculate_mape(\n",
        "    jp_prosp_deaths_df, calculate_mape_apply_fn_args)\n",
        "\n",
        "# Plot both prospective series on one axes.\n",
        "fig, ax = plt.subplots(figsize=(16, 6))\n",
        "ax.plot(jp_prosp_confirmed_xs, jp_prosp_confirmed_ys, color='b', marker='o',\n",
        "        label='confirmed, prospective')\n",
        "ax.plot(jp_prosp_deaths_xs, jp_prosp_deaths_ys, color='g', marker='o',\n",
        "        label='deaths, prospective')\n",
        "\n",
        "# Formatting.\n",
        "ax.grid()\n",
        "ax.legend()\n",
        "ax.set_xlabel('train date')\n",
        "ax.set_ylabel('MAPE')\n",
        "ax.set_title(\n",
        "    f'MAPE vs training date (\u003e= {calculate_mape_apply_fn_args[\"min_mae\"]} MAE)')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xB0pgwiChHfk"
      },
      "outputs": [],
      "source": [
        "# Mean and 95% confidence interval of the min-MAE series above.\n",
        "print_mape_stats(\n",
        "    locale='Japan',\n",
        "    avg_type='Macro average, with mae threshold',\n",
        "    prosp_confirmed_ys=jp_prosp_confirmed_ys,\n",
        "    prosp_deaths_ys=jp_prosp_deaths_ys,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_rmjrS8vvVUG"
      },
      "source": [
        "### US"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x9LBrVZnD6pI"
      },
      "source": [
        "#### Load prospective data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IY3cO03ollBW"
      },
      "outputs": [],
      "source": [
        "# Load data from BQ.\n",
        "# Only load forecasts whose forecast_date falls within the prospective\n",
        "# period.\n",
        "bq_table = 'bigquery-public-data.covid19_public_forecasts.state_28d_historical'\n",
        "# Adjacent f-strings are concatenated verbatim, so every fragment carries its\n",
        "# own trailing space to keep the SQL tokens separated.\n",
        "q = (f\"select * from `{bq_table}` \"\n",
        "     f\"where forecast_date \u003e= '{prospective_start.isoformat()}' \"\n",
        "     f\"and forecast_date \u003c= '{prospective_end.isoformat()}'\")\n",
        "us_forecast_pd = client.query(q).to_dataframe()\n",
        "\n",
        "# The US has double Hawaii, which we should remove.\n",
        "with_double_hawaii = len(us_forecast_pd)\n",
        "us_forecast_pd_dedupped = us_forecast_pd.drop_duplicates(\n",
        "    subset=['state_fips_code', 'state_name', 'prediction_date', 'forecast_date'])\n",
        "dups_dropped = with_double_hawaii - len(us_forecast_pd_dedupped)\n",
        "print(f'dropped {dups_dropped} rows...')\n",
        "\n",
        "# Pair the forecasts with ground truth, separately for cases and deaths.\n",
        "us_prosp_confirmed_df = gather_prospective_data(\n",
        "    us_forecast_pd_dedupped,\n",
        "    pred_feature_key='cumulative_confirmed',\n",
        "    gt_feature_keys=['jhu_state_confirmed_cases'],\n",
        "    loc_key='state_name',\n",
        "    locale='state',\n",
        "    use_latest_gt=False)\n",
        "us_prosp_deaths_df = gather_prospective_data(\n",
        "    us_forecast_pd_dedupped,\n",
        "    pred_feature_key='cumulative_deaths',\n",
        "    gt_feature_keys=['jhu_state_deaths'],\n",
        "    loc_key='state_name',\n",
        "    locale='state',\n",
        "    use_latest_gt=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6u52Ts_3EB-F"
      },
      "source": [
        "#### Plot micro average.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Dh_2g6bG9Vy"
      },
      "outputs": [],
      "source": [
        "# Micro average.\n",
        "\n",
        "calculate_mape_apply_fn_args = dict(\n",
        "    average_type='micro',\n",
        "    expected_num_locations=51,\n",
        "    min_mae=None,\n",
        "    min_count=None,\n",
        "    value_type='4week',\n",
        ")\n",
        "\n",
        "# Prospective series for confirmed cases and deaths.\n",
        "us_prosp_confirmed_xs, us_prosp_confirmed_ys = calculate_mape(\n",
        "    us_prosp_confirmed_df, calculate_mape_apply_fn_args)\n",
        "us_prosp_deaths_xs, us_prosp_deaths_ys = calculate_mape(\n",
        "    us_prosp_deaths_df, calculate_mape_apply_fn_args)\n",
        "\n",
        "# Plot both prospective series on one axes.\n",
        "fig, ax = plt.subplots(figsize=(16, 6))\n",
        "ax.plot(us_prosp_confirmed_xs, us_prosp_confirmed_ys, color='b', marker='o',\n",
        "        label='confirmed, prospective')\n",
        "ax.plot(us_prosp_deaths_xs, us_prosp_deaths_ys, color='g', marker='o',\n",
        "        label='deaths, prospective')\n",
        "\n",
        "# Formatting.\n",
        "ax.grid()\n",
        "ax.legend()\n",
        "ax.set_xlabel('train date')\n",
        "ax.set_ylabel('MAPE')\n",
        "ax.set_title('MAPE vs training date, Micro average')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EET4zfCSN-gL"
      },
      "outputs": [],
      "source": [
        "# Mean and 95% confidence interval of the micro-average series above.\n",
        "print_mape_stats(\n",
        "    locale='US State',\n",
        "    avg_type='Micro average',\n",
        "    prosp_confirmed_ys=us_prosp_confirmed_ys,\n",
        "    prosp_deaths_ys=us_prosp_deaths_ys,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_ovvYhnyEIe5"
      },
      "source": [
        "#### Plot macro average."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vozAJXZhFBN4"
      },
      "outputs": [],
      "source": [
        "# Macro average.\n",
        "calculate_mape_apply_fn_args = dict(\n",
        "    average_type='macro',\n",
        "    expected_num_locations=51,\n",
        "    min_mae=None,\n",
        "    min_count=None,\n",
        "    value_type='4week',\n",
        ")\n",
        "\n",
        "# Prospective series for confirmed cases and deaths.\n",
        "us_prosp_confirmed_xs, us_prosp_confirmed_ys = calculate_mape(\n",
        "    us_prosp_confirmed_df, calculate_mape_apply_fn_args)\n",
        "us_prosp_deaths_xs, us_prosp_deaths_ys = calculate_mape(\n",
        "    us_prosp_deaths_df, calculate_mape_apply_fn_args)\n",
        "\n",
        "# Plot both prospective series on one axes.\n",
        "fig, ax = plt.subplots(figsize=(16, 6))\n",
        "ax.plot(us_prosp_confirmed_xs, us_prosp_confirmed_ys, color='b', marker='o',\n",
        "        label='confirmed, prospective')\n",
        "ax.plot(us_prosp_deaths_xs, us_prosp_deaths_ys, color='g', marker='o',\n",
        "        label='deaths, prospective')\n",
        "\n",
        "# Formatting.\n",
        "ax.grid()\n",
        "ax.legend()\n",
        "ax.set_xlabel('train date')\n",
        "ax.set_ylabel('MAPE')\n",
        "ax.set_title('MAPE vs training date, macro average')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gJwTZC5fipnr"
      },
      "outputs": [],
      "source": [
        "# Mean and 95% confidence interval of the macro-average series above.\n",
        "print_mape_stats(\n",
        "    locale='US State',\n",
        "    avg_type='Macro average',\n",
        "    prosp_confirmed_ys=us_prosp_confirmed_ys,\n",
        "    prosp_deaths_ys=us_prosp_deaths_ys,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TwBl1dc_EKP-"
      },
      "source": [
        "#### Plot macro average, with min mae."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yqJCVhmzEV0j"
      },
      "outputs": [],
      "source": [
        "# Macro average, with min mae.\n",
        "\n",
        "calculate_mape_apply_fn_args = dict(\n",
        "    average_type='macro',\n",
        "    expected_num_locations=51,\n",
        "    min_mae=100.0,\n",
        "    min_count=None,\n",
        "    value_type='4week',\n",
        ")\n",
        "\n",
        "# Prospective series for confirmed cases and deaths.\n",
        "us_prosp_confirmed_xs, us_prosp_confirmed_ys = calculate_mape(\n",
        "    us_prosp_confirmed_df, calculate_mape_apply_fn_args)\n",
        "us_prosp_deaths_xs, us_prosp_deaths_ys = calculate_mape(\n",
        "    us_prosp_deaths_df, calculate_mape_apply_fn_args)\n",
        "\n",
        "# Plot both prospective series on one axes.\n",
        "fig, ax = plt.subplots(figsize=(16, 6))\n",
        "ax.plot(us_prosp_confirmed_xs, us_prosp_confirmed_ys, color='b', marker='o',\n",
        "        label='confirmed, prospective')\n",
        "ax.plot(us_prosp_deaths_xs, us_prosp_deaths_ys, color='g', marker='o',\n",
        "        label='deaths, prospective')\n",
        "\n",
        "# Formatting.\n",
        "ax.grid()\n",
        "ax.legend()\n",
        "ax.set_xlabel('train date')\n",
        "ax.set_ylabel('MAPE')\n",
        "ax.set_title(\n",
        "    f'MAPE vs training date (\u003e= {calculate_mape_apply_fn_args[\"min_mae\"]} MAE)')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xNDe2uzviyLj"
      },
      "outputs": [],
      "source": [
        "# Mean and 95% confidence interval of the min-MAE series above.\n",
        "print_mape_stats(\n",
        "    locale='US State',\n",
        "    avg_type='Macro average, with min mae threshold',\n",
        "    prosp_confirmed_ys=us_prosp_confirmed_ys,\n",
        "    prosp_deaths_ys=us_prosp_deaths_ys,\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "prospective metrics",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1kunYvKhD6DfUywmGvJnbQGhnSC7m-C2u",
          "timestamp": 1606442291564
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
